{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial on \"Decollate a batch\"\n",
    "\n",
    "# `What is decollate?`\n",
    "\n",
    "`decollate batch` is a highlight feature in MONAI v0.6, which simplifies the post processing transforms and provides flexible following operations on a batch of data with various data shape.\n",
    "\n",
    "1. As a preprocessing step in a regular PyTorch program, we usually apply transforms to each input item and `collate` the processed data into a mini-batch (via a PyTorch dataloader with a `collate_fn`). The 'batched' data are used in the rest of the workflow, e.g. for the model forward, training loss computation steps:\n",
    "\n",
    "![image](../figures/collate_batch.png)\n",
    "\n",
    "2. As of MONAI v0.6, we recommand a decollating operation as the first postprocessing step, to convert a 'batched' data (e.g. model predictions) into a list of tensors. \n",
    "The typical logic from `decollate batch`:\n",
    "\n",
    "![image](../figures/decollate_batch.png)\n",
    "\n",
    "\n",
    "## `Why decollate?`\n",
    "The benefits of this 'decollating' operation are:\n",
    "\n",
    "(1) we can execute postprocessing transforms for each item in the output mini-batch respectively, some randomised transforms could be applied with different randomised behaviour for each prediction independently.\n",
    "\n",
    "(2) Both the preprocessing and postprocessing transforms only need to support `channel-first` shape of input data. this simplifies the transform API design, and reduces input validation burdens.\n",
    "\n",
    "(3) It allows to apply `Invertd` transform for the predictions and the inverted data can have different shape, because they are in a list, not stacked in a single batch tensor anymore.\n",
    "\n",
    "(4) All the MONAI metrics can support both `batch-first` tensor and list of `channel-first` tensors, so we can compute metrics for the inverted data (potentially in different data shape) directly.\n",
    "\n",
    "\n",
    "\n",
    "## `How to decollate?`\n",
    "\n",
    "The rest of the tutorial shows a detailed example program that executes a typical `collate batch` and `decollate batch` workflows.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/decollate_batch.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel, tqdm]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import sys\n",
    "import shutil\n",
    "import tempfile\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "from glob import glob\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.data import create_test_image_3d, Dataset, DataLoader, decollate_batch\n",
    "from monai.handlers import from_engine\n",
    "from monai.inferers import sliding_window_inference\n",
    "from monai.metrics import DiceMetric\n",
    "from monai.networks.nets import UNet\n",
    "from monai.transforms import (\n",
    "    Activationsd,\n",
    "    EnsureChannelFirstd,\n",
    "    EnsureTyped,\n",
    "    AsDiscreted,\n",
    "    Compose,\n",
    "    Invertd,\n",
    "    LoadImaged,\n",
    "    Orientationd,\n",
    "    Resized,\n",
    "    SaveImaged,\n",
    "    ScaleIntensityd,\n",
    ")\n",
    "from monai.utils import set_determinism\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set determinism, logging, device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_determinism(seed=0)\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "device = torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate random (image, label) pairs\n",
    "\n",
    "Generate 5 `image` and `label` pairs for this evaluation task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(5):\n",
    "    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)\n",
    "\n",
    "    n = nib.Nifti1Image(im, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"img{i:d}.nii.gz\"))\n",
    "\n",
    "    n = nib.Nifti1Image(seg, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"seg{i:d}.nii.gz\"))\n",
    "\n",
    "images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n",
    "segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n",
    "files = [{\"img\": img, \"seg\": seg} for img, seg in zip(images, segs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup preprocessing transforms, dataset, dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocessing = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=[\"img\", \"seg\"]),\n",
    "        EnsureChannelFirstd(keys=[\"img\", \"seg\"]),\n",
    "        Orientationd(keys=\"img\", axcodes=\"RAS\"),\n",
    "        Resized(keys=\"img\", spatial_size=(96, 96, 96), mode=\"trilinear\", align_corners=True),\n",
    "        ScaleIntensityd(keys=\"img\"),\n",
    "        EnsureTyped(keys=[\"img\", \"seg\"]),\n",
    "    ]\n",
    ")\n",
    "dataset = Dataset(data=files, transform=preprocessing)\n",
    "dataloader = DataLoader(dataset, batch_size=1, num_workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup postprocessing transforms, metrics\n",
    "Here we try to invert the preprocessing predictions for `pred` and save into Nifti files.\n",
    "\n",
    "As all the post processing transforms expect `Tensor` input, apply `EnsureTyped` first to ensure the data type after `decollate_batch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "postprocessing = Compose(\n",
    "    [\n",
    "        EnsureTyped(keys=[\"pred\", \"seg\"]),  # ensure Tensor type after `decollate`\n",
    "        Activationsd(keys=\"pred\", sigmoid=True),\n",
    "        Invertd(\n",
    "            keys=\"pred\",  # invert the `pred` data field, also support multiple fields\n",
    "            transform=preprocessing,\n",
    "            orig_keys=\"img\",  # get the previously applied pre_transforms information on the `img` data field\n",
    "            meta_keys=\"pred_meta_dict\",  # key field to save inverted meta data, every item maps to `keys`\n",
    "            orig_meta_keys=\"img_meta_dict\",  # use the meta data from `img_meta_dict` field when inverting\n",
    "            nearest_interp=False,  # don't change the interpolation mode of preprocessing when inverting\n",
    "            to_tensor=True,\n",
    "            device=device,\n",
    "        ),\n",
    "        AsDiscreted(keys=\"pred\", threshold=0.5),\n",
    "        SaveImaged(keys=\"pred\", meta_keys=\"pred_meta_dict\", output_dir=root_dir, resample=False),\n",
    "    ]\n",
    ")\n",
    "# will compute mean dice on the decollated `predictions` and `labels`, which are list of `channel-first` tensors\n",
    "dice_metric = DiceMetric(include_background=True, reduction=\"mean\", get_not_nans=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute the evaluation progress with all the above components\n",
    "Here we use a randomly initialized `UNet` to execute evaluation, usually we load a pretrained weights in the real-world practice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(16, 32, 64, 128, 256),\n",
    "    strides=(2, 2, 2, 2),\n",
    "    num_res_units=2,\n",
    ").to(device)\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for data in dataloader:\n",
    "        images, labels = data[\"img\"].to(device), data[\"seg\"].to(device)\n",
    "        # define sliding window size and batch size for windows inference\n",
    "        roi_size = (64, 64, 64)\n",
    "        sw_batch_size = 4\n",
    "        data[\"pred\"] = sliding_window_inference(images, roi_size, sw_batch_size, model)\n",
    "        data[\"seg\"] = labels\n",
    "\n",
    "        # decollate the batch data into list of dictionaries, every dictionary maps to an input data\n",
    "        data = [postprocessing(i) for i in decollate_batch(data)]\n",
    "        # extract a list of `prections` and a list of `labels` with the `from_engine` utility\n",
    "        pred, y = from_engine([\"pred\", \"seg\"])(data)\n",
    "        # compute mean dice for current iteration\n",
    "        dice_metric(y_pred=pred, y=y)\n",
    "    # aggregate the final mean dice result\n",
    "    print(f\"evaluation metric: {dice_metric.aggregate().item()}\")\n",
    "    # reset the metric status\n",
    "    dice_metric.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove directory if a temporary was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
